import os
import sys
import redis
import copy
import logging
import networkx as nx

from database.rdb_client import *
from base.defines import CompileType, Attribute, FileType
from utils.singleton import Singleton
from command_parser.gcc_parse.gcc_command_parser import GCCParser
from command_parser.clang_parse.clang_command_parser import ClangParser
from command_parser.ld_parse.ld_command_parser import LDParser
from command_parser.ar_parse.ar_command_parser import ARParser
from utils.db_object_pool import DBObjectPool
from utils.file_type_detect import is_assembly, is_source, is_staticlib, is_sharedlib


class BuildGraph(metaclass=Singleton):
    """Build-dependency graph reconstructed from recorded tool invocations.

    Nodes are the keys (named ``*_md5`` in this code) of the per-tool DB
    handlers' dependency maps. An edge ``dst -> src`` means ``dst`` is produced
    from ``src`` by the tool whose handler recorded it. Per-node metadata lives
    in ``node_attr_map``: node -> {Attribute member: value}.
    """

    def __init__(self):
        # Directed graph whose edges point from a build output to its inputs.
        self.build_graph = nx.DiGraph()
        # node -> {Attribute: value}
        self.node_attr_map = {}

        # Cached result of get_root_nodes().
        self.root_nodes = []

        # Per-tool DB handlers; fetched from DBObjectPool in construct_build_graph().
        self.gcc_db_handler = None
        self.clang_db_handler = None
        self.ld_db_handler = None
        self.ar_db_handler = None
        self.ranlib_db_handler = None
        self.strip_db_handler = None

        self.initized = False

    def is_initized(self):
        """Return True once construct_build_graph() has completed."""
        return self.initized

    def construct_build_graph(self):
        """Fetch the tool handlers, simplify ranlib/strip chains, then build the
        graph and compute per-node attributes."""
        self.gcc_db_handler = DBObjectPool.get_db_inst_by_name("gcc")
        self.clang_db_handler = DBObjectPool.get_db_inst_by_name("clang")
        self.ld_db_handler = DBObjectPool.get_db_inst_by_name("ld")
        self.ar_db_handler = DBObjectPool.get_db_inst_by_name("ar")
        self.ranlib_db_handler = DBObjectPool.get_db_inst_by_name("ranlib")
        self.strip_db_handler = DBObjectPool.get_db_inst_by_name("strip")

        self.remove_useless_dependence()
        self.generate_build_graph()
        self.update_nodes_attribute()

        self.initized = True

    def read_attribute(self, node="", attribute=-1):
        """Return ``attribute`` of ``node``, or None if either is unset/unknown."""
        if not node or attribute == -1:
            return None

        node_attr = self.node_attr_map.get(node, None)
        if not node_attr:
            return None

        return node_attr.get(attribute, None)

    def write_attribute(self, node="", attribute=-1, value=None):
        """Set ``attribute`` of ``node`` to ``value``.

        Silently does nothing if ``node``/``attribute`` is unset or ``node`` has
        no attribute map yet.
        """
        if not node or attribute == -1:
            return

        node_attr = self.node_attr_map.get(node, None)
        if not node_attr:
            return

        node_attr[attribute] = value

    def update_nodes_attribute(self):
        """Compute IN_DEGREE for every node reachable from a root and create
        attribute maps for leaf (input) nodes.

        Each node's successors are expanded exactly once, and IN_DEGREE is
        incremented once per incoming edge. It therefore equals the node's
        number of predecessors; roots stay at 0. A leaf gets its attribute map
        the first time it is reached, typed from that first parent.

        Raises:
            AssertionError: if a root node has no successors (cannot happen for
                nodes created via add_edge).
        """
        expanded = set()

        def describe_leaf(node, parent):
            """Build the attribute map for leaf ``node`` reached from ``parent``."""
            if parent is None:
                raise AssertionError("leaf node %r has no parent" % (node,))

            db_handler = self.read_attribute(node=parent, attribute=Attribute.DB_HANDLER)
            parent_type = self.read_attribute(node=parent, attribute=Attribute.FILETYPE)
            my_type = FileType.UNKNOWN
            if parent_type == FileType.COMPILE:
                file_name = db_handler.get_name_dep().get(node, None)
                if file_name:
                    if is_assembly(file_name):
                        my_type = FileType.ASSEMBLE
                    if is_source(file_name):
                        my_type = FileType.SOURCE
            return {
                Attribute.DB_HANDLER: db_handler,
                Attribute.FILETYPE: my_type,
                # Incremented by the caller once per incoming edge.
                Attribute.IN_DEGREE: 0,
            }

        def expand(node, parent):
            """Visit ``node``'s successors, counting one in-degree per edge."""
            # list(): successors() may return an iterator, which is always truthy.
            successors = list(self.build_graph.successors(node))
            if not successors:
                self.node_attr_map[node] = describe_leaf(node, parent)
                return

            for succ in successors:
                # Expand shared subgraphs only once; also makes cycles terminate.
                if succ not in expanded:
                    expanded.add(succ)
                    expand(succ, node)
                dep = self.read_attribute(node=succ, attribute=Attribute.IN_DEGREE) or 0
                self.write_attribute(node=succ, attribute=Attribute.IN_DEGREE, value=dep + 1)

        for root in self.get_root_nodes():
            self.write_attribute(node=root, attribute=Attribute.IN_DEGREE, value=0)
            if root not in expanded:
                expanded.add(root)
                expand(root, None)

    def generate_build_graph(self):
        """Add one node per recorded gcc/clang/ld/ar output and an edge from it
        to each of its recorded inputs."""
        handler_list = [(CompileType.GCC, self.gcc_db_handler),
                        (CompileType.CLANG, self.clang_db_handler),
                        (CompileType.LD, self.ld_db_handler),
                        (CompileType.AR, self.ar_db_handler)]

        parser_dict = {CompileType.GCC: GCCParser,
                       CompileType.CLANG: ClangParser,
                       CompileType.LD: LDParser,
                       CompileType.AR: ARParser}

        file_type_dict = {CompileType.GCC: FileType.COMPILE,
                          CompileType.CLANG: FileType.COMPILE,
                          CompileType.LD: FileType.LINK,
                          CompileType.AR: FileType.LINK}

        for compile_type, handler in handler_list:
            name_dep = handler.get_name_dep()
            link_dep = handler.get_link_dep()
            cmd_dep = handler.get_cmd_dep()

            for dst_md5, deps in link_dep.items():
                if not dst_md5:
                    continue
                cwd, command = cmd_dep.get(dst_md5, [None, []])
                # Skip outputs with neither a recorded cwd nor a command.
                if not cwd and not command:
                    continue
                deps = deps or []

                self.node_attr_map[dst_md5] = {
                    Attribute.NAME: name_dep.get(dst_md5),
                    Attribute.CWD: cwd,
                    Attribute.COMPILER: compile_type,
                    Attribute.COMMAND: command,
                    Attribute.DB_HANDLER: handler,
                    Attribute.PARSER: parser_dict.get(compile_type),
                    Attribute.OUT_DEGREE: len(deps),
                    Attribute.IN_DEGREE: 0,  # computed in update_nodes_attribute()
                    Attribute.FILETYPE: file_type_dict.get(compile_type, FileType.UNKNOWN),
                }

                for src_md5 in deps:
                    self.build_graph.add_edge(dst_md5, src_md5)

    def remove_useless_dependence(self):
        """Collapse ranlib<->strip chains, then fold the remaining ranlib/strip
        entries into the gcc/clang/ld/ar entry that produced their input."""
        ranlib_link_dep = self.ranlib_db_handler.get_link_dep()
        strip_link_dep = self.strip_db_handler.get_link_dep()

        # ranlib first. Given ranlib: dst3 -> src2 and strip: dst2 -> src1 with src2 == dst2:
        # -> update ranlib: dst3 -> src1
        # -> remove strip:  dst2 -> src1
        # Database data is not synchronized for now.
        for dst3_md5, src2_md5 in list(ranlib_link_dep.items()):
            src1_md5 = strip_link_dep.get(src2_md5, None)
            if not src1_md5:
                continue
            ranlib_link_dep[dst3_md5] = src1_md5
            strip_link_dep.pop(src2_md5)
            self.ranlib_db_handler.update_link_dep(dst3_md5, src1_md5)
            self.strip_db_handler.del_link_dep(src2_md5)

        # Then strip. Given strip: dst3 -> src2 and ranlib: dst2 -> src1 with src2 == dst2:
        # -> update strip:  dst3 -> src1
        # -> remove ranlib: dst2 -> src1
        # Database data is not synchronized for now.
        for dst3_md5, src2_md5 in list(strip_link_dep.items()):
            src1_md5 = ranlib_link_dep.get(src2_md5, None)
            if not src1_md5:
                continue
            strip_link_dep[dst3_md5] = src1_md5
            ranlib_link_dep.pop(src2_md5)
            self.strip_db_handler.update_link_dep(dst3_md5, src1_md5)
            self.ranlib_db_handler.del_link_dep(src2_md5)

        # Remove the ranlib and strip dependencies.
        self.remove_redundant_dependence(dst_link_dep=ranlib_link_dep, dst_handler=self.ranlib_db_handler)
        self.remove_redundant_dependence(dst_link_dep=strip_link_dep, dst_handler=self.strip_db_handler)

    def remove_redundant_dependence(self, dst_link_dep=None, dst_handler=None):
        """Fold each ``dst3 -> src2`` entry of ``dst_link_dep`` into the tool entry
        that produced ``src2``.

        For a gcc/clang/ld/ar entry ``src2 -> src1[, src0]`` this re-keys it as
        ``dst3 -> src1[, src0]`` and drops ``dst3`` from ``dst_handler``. Handlers
        are checked in the order gcc, clang, ld, ar.
        """
        if dst_link_dep is None:
            dst_link_dep = {}

        source_tables = [(src_handler, src_handler.get_link_dep(), src_handler.get_cmd_dep())
                         for src_handler in (self.gcc_db_handler,
                                             self.clang_db_handler,
                                             self.ld_db_handler,
                                             self.ar_db_handler)]

        # Iterate over a snapshot: update_dependency() calls dst_handler.del_link_dep(),
        # which may remove entries from this very dict.
        for dst3_md5, src2_md5 in list(dst_link_dep.items()):
            for src_handler, src_link_dep, src_cmd_dep in source_tables:
                if src2_md5 in src_link_dep:
                    self.update_dependency(dst_md5=dst3_md5, src_md5=src2_md5,
                                           src_handler=src_handler, dst_handler=dst_handler,
                                           src_link_dep=src_link_dep, src_cmd_dep=src_cmd_dep)

    def update_dependency(self, dst_md5="", src_md5="", src_handler=None, dst_handler=None,
                          src_link_dep=None, src_cmd_dep=None):
        """Move ``src_md5``'s link/cmd entries in ``src_handler`` to ``dst_md5``,
        delete ``src_md5``'s name entry, and drop ``dst_md5`` from ``dst_handler``.
        """
        # Known limitation: with src->gcc->obj, obj->ld->bin and obj->ranlib->lib,
        # the obj entry is re-keyed to lib, so rebuilding bin can no longer find
        # its source file.
        link_value = (src_link_dep or {}).get(src_md5)
        cmd_value = (src_cmd_dep or {}).get(src_md5)

        dst_handler.del_link_dep(dst_md5)
        src_handler.del_name_dep(src_md5)

        src_handler.del_link_dep(src_md5)
        src_handler.set_link_dep(dst_md5, link_value)

        src_handler.del_cmd_dep(src_md5)
        src_handler.set_cmd_dep(dst_md5, cmd_value)

    def get_reversed_graph_with_degree(self):
        """Return a reversed copy of the graph with each node's in-degree (in the
        reversed graph) stored as a node attribute.

        NOTE(review): the attribute key is the node id itself, not a fixed name
        such as "in_degree" — looks unintended; confirm with callers.
        """
        reversed_graph = self.build_graph.reverse(copy=True)

        for node, in_degree in reversed_graph.in_degree_iter():
            reversed_graph.add_node(node, {node: in_degree})

        return reversed_graph

    def get_root_nodes(self):
        """Return (and cache) the nodes with no predecessors, i.e. final outputs."""
        if not self.root_nodes:
            self.root_nodes = [node for node, in_degree in self.build_graph.in_degree_iter() if in_degree == 0]

        return self.root_nodes

    def graph_dump(self, dump_file=""):
        """Render the build graph to ``dump_file`` as a PNG with Graphviz ``dot``.

        Nodes that have a DB handler and a recorded name are labelled with the
        name's basename, and all nodes are coloured by their FILETYPE.

        Returns:
            None if ``dump_file`` is empty; False if drawing raised an
            AttributeError (which is logged); True otherwise.
        """
        import contextlib

        if not dump_file:
            logging.warning("Do not specify build graph dump file")
            return

        comment_by_type = {FileType.SOURCE: "source",
                           FileType.LINK: "link",
                           FileType.COMPILE: "compile"}

        graph_copy = self.build_graph.copy()
        for node in list(graph_copy.nodes()):
            db_handler = self.read_attribute(node=node, attribute=Attribute.DB_HANDLER)
            if not db_handler:
                continue

            node_name = db_handler.get_name_dep().get(node, None)
            if not node_name:
                continue

            node_type = self.read_attribute(node=node, attribute=Attribute.FILETYPE)
            graph_copy.add_node(node, label=os.path.basename(node_name),
                                comment=comment_by_type.get(node_type, 'unknown'))

        draw_graph = nx.drawing.nx_agraph.to_agraph(graph_copy)
        draw_graph.graph_attr['concentrate'] = True
        draw_graph.graph_attr['splines'] = 'curved'
        draw_graph.graph_attr['overlap'] = 'prism'

        # comment -> (fillcolor, shape); anything else is drawn as "unknown".
        style_by_comment = {"source": ('green', 'diamond'),
                            "link": ('blue', 'box'),
                            "compile": ('orange', 'circle')}
        for node in draw_graph.nodes():
            fillcolor, shape = style_by_comment.get(node.attr["comment"], ('red', 'tripleoctagon'))
            node.attr['fillcolor'] = fillcolor
            node.attr['shape'] = shape
            node.attr['style'] = 'filled'

        try:
            # Discard anything written to sys.stderr while drawing
            # (presumably Graphviz/pygraphviz warnings).
            with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull):
                draw_graph.draw(dump_file, prog='dot', format='png')
        except AttributeError as e:
            print("[ERROR] error occured when draw graph")
            logging.exception("It occurs error during drawing dot, message: %s", e)
            return False

        return True
          
        
        
        
                
        
        
